{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "ec5aa9d2-9ceb-441e-b1b4-b70e3eabcab6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import random\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "4cbf93d3-fa82-4926-b724-5e9b10a7b728",
   "metadata": {},
   "outputs": [],
   "source": [
    "WEIGHT = 3.7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "3aa258d6-8fed-4910-8858-0f19504fdfe5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 7.],\n",
       "        [ 4.],\n",
       "        [ 2.],\n",
       "        [ 7.],\n",
       "        [ 6.],\n",
       "        [ 3.],\n",
       "        [10.],\n",
       "        [ 8.],\n",
       "        [ 6.],\n",
       "        [ 7.]], grad_fn=<SliceBackward0>)"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_list = []\n",
    "for _ in range(1000):\n",
    "    x_list.append([random.randint(1, 10)])\n",
    "\n",
    "inputs = torch.tensor(x_list, dtype=torch.float, requires_grad=True)\n",
    "inputs[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "b8a7b78b-987a-4227-acdd-0da6d8499ae9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[25.9000],\n",
       "        [14.8000],\n",
       "        [ 7.4000],\n",
       "        [25.9000],\n",
       "        [22.2000],\n",
       "        [11.1000],\n",
       "        [37.0000],\n",
       "        [29.6000],\n",
       "        [22.2000],\n",
       "        [25.9000]])"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_list = []\n",
    "for x in x_list:\n",
    "    y_list.append([x[0] * WEIGHT])\n",
    "\n",
    "targets = torch.tensor(y_list, dtype=torch.float)\n",
    "targets[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "a54e1c69-e370-40a9-85aa-0f318b602447",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LNTXLinearModel(\n",
       "  (linear): Linear(in_features=1, out_features=1, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 线性模型\n",
    "class LNTXLinearModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(LNTXLinearModel, self).__init__()\n",
    "        self.linear = nn.Linear(in_features=1, out_features=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.linear(x)\n",
    "\n",
    "# 实例化模型和损失函数\n",
    "model = LNTXLinearModel()\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "b8d0207f-8b12-4f49-a8af-b4fa7bd881f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义损失函数\n",
    "loss_fn = nn.MSELoss()\n",
    "\n",
    "# 定义优化器\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "edca5a70-59fd-49bc-b223-eda9b8fc47af",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final weights: 3.6999993324279785\n"
     ]
    }
   ],
   "source": [
    "# 训练循环，迭代10次\n",
    "for epoch in range(10):\n",
    "    # 前向传播\n",
    "    predictions = model(inputs)\n",
    "    \n",
    "    # 计算损失\n",
    "    loss = loss_fn(predictions, targets)\n",
    "    \n",
    "    # 清空梯度\n",
    "    optimizer.zero_grad()\n",
    "    \n",
    "    # 反向传播\n",
    "    loss.backward()\n",
    "    \n",
    "    # 更新参数\n",
    "    optimizer.step()\n",
    "\n",
    "# 输出最终的权重\n",
    "print(\"Final weights:\", model.linear.weight.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "id": "b167cce0-95a8-40fd-bb19-d0d981d67c89",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "11.10000228881836"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def eval(x_input):\n",
    "    return model(torch.tensor([float(x_input)])).item()\n",
    "    \n",
    "# 使用模型\n",
    "eval(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4c63e97-a7bf-4f7d-97dd-8bc8103678dd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
